import numpy as np


class RLDP:
    """Tabular MDP solved by synchronous (Jacobi-style) value iteration.

    Holds one state value per state (``V``) and one action value per
    state/action pair (``Q``), both initialised to zero.
    """

    def __init__(self, num_states, num_actions, dynamics, rewards):
        """Store the MDP model.

        Args:
            num_states: Number of discrete states.
            num_actions: Number of discrete actions.
            dynamics: Array indexed ``dynamics[s, a, s_next]``; presumably the
                transition probabilities P(s_next | s, a) -- not validated here.
            rewards: Array indexed ``rewards[s, a]`` giving the immediate reward.

        Both model arrays are copied, so later mutation by the caller does not
        affect the solver.
        """
        self.num_states = num_states
        self.num_actions = num_actions

        self.V = np.zeros(num_states)
        self.Q = np.zeros((num_states, num_actions))
        # np.array copies by default and, unlike ``.copy()``, also accepts
        # nested Python sequences.
        self.r = np.array(rewards)
        self.dynamics = np.array(dynamics)

    def value_iteration(self, num_iterations=1000, gamma=0.9,
                        terminal_states=(9,), tol=None, verbose=True):
        """Run value iteration and return the greedy policy with its Q-table.

        Each sweep computes, from the previous sweep's ``V``::

            Q[s, a] = r[s, a] + gamma * sum_s' dynamics[s, a, s'] * V[s']

        Terminal states do not bootstrap: they get ``Q[s, a] = r[s, a]``.
        ``V[s]`` is then set to ``max_a Q[s, a]``. ``V`` and ``Q`` are updated
        in place and persist between calls (a second call warm-starts).

        Args:
            num_iterations: Maximum number of sweeps.
            gamma: Discount factor.
            terminal_states: State indices treated as terminal. The default
                ``(9,)`` keeps the historical hard-coded behaviour. Indices
                outside ``[0, num_states)`` are ignored.
            tol: If not None, stop early once the largest absolute change in
                ``V`` over one sweep is <= ``tol``.
            verbose: If True, print the sweep index every 100 sweeps.

        Returns:
            Tuple ``(policy, Q)``. ``policy[s]`` is the index of the first
            action maximising ``Q[s]``. ``Q`` is ``self.Q`` itself, not a copy.
        """
        # Boolean mask of terminal states. Out-of-range entries are skipped so
        # small models behave as before (state 9 simply never occurred).
        terminal = np.zeros(self.num_states, dtype=bool)
        for s in terminal_states:
            if 0 <= s < self.num_states:
                terminal[s] = True

        for t in range(num_iterations):
            if verbose and t % 100 == 0:
                print(t)
            # Expected next-state value for every (s, a) in one contraction
            # over the last axis of ``dynamics``.
            expected = self.dynamics @ self.V
            self.Q[:] = np.where(terminal[:, None], self.r,
                                 self.r + gamma * expected)
            new_v = self.Q.max(axis=1)
            converged = tol is not None and (
                new_v.size == 0 or np.max(np.abs(new_v - self.V)) <= tol)
            # Assign in place so external references to self.V stay valid.
            self.V[:] = new_v
            if converged:
                break

        return np.argmax(self.Q, axis=1), self.Q
